import jax
import jax.numpy as jnp
import jaxon.models.punit as punit
from jaxon.dsp.kernels import gauss_kernel
from jaxon.dsp.rate import spike_rate
duration = 10  # simulation duration in seconds
# parameters for the P-Unit model
punit_params = {
    "cell": "2010-11-08-al-invivo-1",
    "EODf": 744.66,
    "a_zero": 9.450855200303527,
    "delta_a": 0.0604984400793618,
    "dend_tau": 0.0007742334994649,
    "input_scaling": 31.363843698084207,
    "mem_tau": 0.0017257848281706,
    "noise_strength": 0.0124091008125932,
    "ref_period": 0.0010273077926126,
    "deltat": 5e-05,
    "tau_a": 0.1022386553157565,
    "threshold": 1,
    "v_base": 0,
    "v_offset": -0.390625,
    "v_zero": 0,
}
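# remove the cell ID and the EOD frequency, which are not passed to PUnitParams
cell = punit_params.pop("cell")
eodf = punit_params.pop("EODf")
params = punit.PUnitParams(**punit_params)
# parameters for the Gaussian kernel used to estimate the firing rate
sigma = 0.007  # standard deviation of the kernel in seconds
ktime = 4
# sampling rate in Hz, given by the integration time step
fs = 1 / params.deltat
# random key for the stochastic P-Unit model (noise term)
key = jax.random.PRNGKey(42)
# 10 independent sub-keys, e.g. for repeated trials (not used below)
keys = jax.random.split(key, 10)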
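# time axis sampled with the integration time step
time = jnp.arange(0, duration, 1 / fs)
# baseline stimulus: the EOD with its amplitude normalized to one
stimulus = jnp.cos(2 * jnp.pi * eodf * time)
# simulate the model: binary spike train (one entry per time step) and membrane voltage
binary_spikes, vmem = punit.simulate(key, stimulus, params)
# smooth the spike train with a Gaussian kernel to estimate the firing rate
kernel = gauss_kernel(sigma, 1 / fs, ktime)
rate = spike_rate(binary_spikes, kernel)

P-Unit Model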
1. Motivation and implementation
The P-Unit model is a leaky integrate-and-fire (LIF) model that additionally incorporates several stages of the electrosensory pathway of weakly electric fish. In the baseline condition, the input to the model is a sinusoid at the frequency of the fish's electric organ discharge (EOD), with its amplitude normalized to one:
\[ S(t) = S_{EOD}(t) = \cos(2\pi f_{EOD} t) \]
P-Units respond to amplitude changes of their carrier, the EOD. To stimulate the P-Unit model with Gaussian white noise \(\xi(t)\), the baseline EOD is multiplied by the amplitude modulation \(1 + c\,\xi(t)\), where the contrast \(c\) defaults to 10%:
\[ S_{am}(t) = S_{EOD}(t) + c\,\xi(t)\,S_{EOD}(t) = S_{EOD}(t)\,\bigl(1 + c\,\xi(t)\bigr) \]
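A minimal NumPy sketch of the two stimuli above, using the EOD frequency and time step from the example. It assumes that \(\xi(t)\) is unit-variance Gaussian white noise drawn independently at every time step; noise stimuli used in experiments are usually band-limited, which this sketch does not do. The helper `am_stimulus` is hypothetical and not part of jaxon.

```python
import numpy as np


def am_stimulus(eodf, duration, deltat, contrast=0.1, seed=0):
    """Return time axis, baseline EOD and noise-modulated EOD (hypothetical helper)."""
    n = int(round(duration / deltat))
    t = np.arange(n) * deltat
    # baseline: S_EOD(t) = cos(2 pi f_EOD t), amplitude normalized to one
    s_eod = np.cos(2 * np.pi * eodf * t)
    # assumption: xi(t) is unit-variance Gaussian white noise, one sample per step
    xi = np.random.default_rng(seed).standard_normal(n)
    # S_am(t) = S_EOD(t) + S_EOD(t) * xi(t) * c
    s_am = s_eod + s_eod * xi * contrast
    return t, s_eod, s_am


# 10 % contrast, EODf = 744.66 Hz and deltat = 50 us as in the example
t_am, s_eod, s_am = am_stimulus(744.66, duration=0.5, deltat=5e-5)
```

Setting `contrast=0.0` returns the unmodulated baseline EOD, which is a quick sanity check of the construction.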
The stimulus then passes through a threshold operation \(\lfloor \cdot \rfloor_{0}\), which sets negative values to zero and models the transmission between the receptor cell and the afferent (P-Unit). The afferent dendrite low-pass filters the thresholded stimulus with the dendritic time constant \(\tau_{d}\):
\[ \tau_{d} \frac{d V_{d}}{d t} = -V_{d}+ \lfloor S(t) \rfloor_{0}^{p} \]
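The dendritic filter can be sketched with forward-Euler integration in plain NumPy. This sketch assumes that \(\lfloor S \rfloor_{0}^{p}\) means rectification at zero followed by raising to the power \(p\) (default \(p = 1\), since the example's parameter set has no such entry) and that \(V_{d}(0) = 0\). The helper `dendrite_filter` is hypothetical, not jaxon's implementation.

```python
import numpy as np


def dendrite_filter(stimulus, deltat, dend_tau, p=1.0, v_d0=0.0):
    """Forward-Euler solution of tau_d dV_d/dt = -V_d + floor(S)_0^p (hypothetical helper)."""
    # assumption: floor(.)_0^p = set negative values to zero, then raise to the power p
    drive = np.maximum(np.asarray(stimulus, dtype=float), 0.0) ** p
    v_d = np.empty_like(drive)
    v = v_d0
    for i, x in enumerate(drive):
        v += deltat / dend_tau * (-v + x)  # one Euler step of the low-pass filter
        v_d[i] = v
    return v_d


deltat, dend_tau = 5e-5, 0.0007742334994649  # values from the example
t_d = np.arange(int(round(1.0 / deltat))) * deltat
v_d = dendrite_filter(np.cos(2 * np.pi * 744.66 * t_d), deltat, dend_tau)
```

The filter has unit gain at zero frequency, so once the initial transient has decayed the mean of \(V_{d}\) approaches the mean of the rectified cosine, \(1/\pi \approx 0.32\), while a residual ripple at the EOD frequency remains because \(\tau_{d}\) is of the same order as the EOD period.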
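The resulting dendritic voltage \(V_{d}\) is multiplied by a scaling factor \(\alpha\) and serves as the input to the LIF neuron. A further addition to the standard LIF model is an adaptation current \(A\), which is subtracted from the membrane voltage. It is increased with every spike and, between spikes, decays exponentially with the adaptation time constant \(\tau_{A}\):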
\[ \tau_{A} \frac{d A}{d t} = - A \]
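A sketch of the adaptation dynamics in NumPy: between spikes, forward-Euler steps of \(\tau_{A}\,dA/dt = -A\) reproduce the exponential decay \(A(t) = A(0)\,e^{-t/\tau_{A}}\). How much \(A\) grows at each spike is not specified in the text, so the increment of `delta_a / tau_a` per spike used here is an assumption, and `adaptation_trace` is a hypothetical helper.

```python
import numpy as np


def adaptation_trace(spikes, deltat, tau_a, delta_a, a0=0.0):
    """Forward-Euler solution of tau_A dA/dt = -A with a jump at every spike (hypothetical helper)."""
    a = a0
    trace = np.empty(len(spikes))
    for i, spike in enumerate(spikes):
        a += deltat / tau_a * (-a)  # exponential decay between spikes
        if spike:
            a += delta_a / tau_a  # assumption: fixed increment per spike
        trace[i] = a
    return trace


deltat, tau_a, delta_a = 5e-5, 0.1022386553157565, 0.0604984400793618  # from the example
# decay only: A(0) = 1 and no spikes for 0.1 s
decay = adaptation_trace(np.zeros(2000), deltat, tau_a, delta_a, a0=1.0)
# build-up: a regular 100 Hz spike train for 1 s
spike_train = np.zeros(20000)
spike_train[::200] = 1
buildup = adaptation_trace(spike_train, deltat, tau_a, delta_a)
```

With a regular spike train, \(A\) builds up over a few \(\tau_{A}\) until the decay between spikes balances the increments, which is what makes the adaptation current grow with the firing rate.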
Finally, there is a refractory period: after the membrane voltage \(V_m(t)\) has crossed the threshold \(\theta = 1\) and a spike is emitted, the integration of \(V_m(t)\) is paused for the duration of the refractory period. The fixed input bias \(\mu\) and the noise term \(\sqrt{2D}\xi(t)\) with noise intensity \(D\) are the same as in the standard LIF model. Together, this gives the following differential equation for the membrane voltage:
\[ \tau_{m} \frac{d V_{m}}{d t} = - V_{m} + \mu + \alpha V_{d} - A + \sqrt{2D}\xi(t) \]
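Putting the pieces together, here is a plain NumPy Euler-Maruyama sketch of the full model, with the membrane update \(V_m \leftarrow V_m + \frac{\Delta t}{\tau_m}(-V_m + \mu + \alpha V_d - A) + \frac{\sqrt{2D}\sqrt{\Delta t}}{\tau_m}\,N(0,1)\). It is not jaxon's `punit.simulate`. It assumes that the parameter names map onto the symbols as \(\mu\) = `v_offset`, \(\alpha\) = `input_scaling`, \(\tau_d\) = `dend_tau`, \(\tau_m\) = `mem_tau`, \(\tau_A\) = `tau_a`, \(\theta\) = `threshold`; that `noise_strength` plays the role of \(\sqrt{2D}\); that \(V_m\) starts at `v_zero` and is reset to `v_base`; that \(A\) starts at `a_zero` and jumps by `delta_a / tau_a` at every spike; and that \(p = 1\).

```python
import numpy as np

# Parameter values from the example (cell 2010-11-08-al-invivo-1).
PARAMS = {
    "a_zero": 9.450855200303527,
    "delta_a": 0.0604984400793618,
    "dend_tau": 0.0007742334994649,
    "input_scaling": 31.363843698084207,
    "mem_tau": 0.0017257848281706,
    "noise_strength": 0.0124091008125932,
    "ref_period": 0.0010273077926126,
    "deltat": 5e-05,
    "tau_a": 0.1022386553157565,
    "threshold": 1.0,
    "v_base": 0.0,
    "v_offset": -0.390625,
    "v_zero": 0.0,
}


def simulate_punit(stimulus, p, seed=0):
    """Euler-Maruyama sketch of the P-Unit model (hypothetical, not jaxon's punit.simulate)."""
    dt = p["deltat"]
    n = len(stimulus)
    # threshold operation: negative values set to zero (p = 1 assumed)
    drive = np.maximum(np.asarray(stimulus, dtype=float), 0.0)
    # noise increment per step: sqrt(2D) * sqrt(dt) / tau_m * N(0, 1), sqrt(2D) = noise_strength (assumed)
    noise = np.random.default_rng(seed).standard_normal(n) * p["noise_strength"] * np.sqrt(dt) / p["mem_tau"]
    ref_steps = int(round(p["ref_period"] / dt))
    v_d, v_m, a = drive[0], p["v_zero"], p["a_zero"]  # assumed initial conditions
    refractory = 0  # remaining steps of the refractory period
    spikes = np.zeros(n, dtype=np.int8)
    vmem = np.empty(n)
    for i in range(n):
        v_d += dt / p["dend_tau"] * (-v_d + drive[i])  # dendritic low-pass filter
        a += dt / p["tau_a"] * (-a)  # adaptation decays between spikes
        if refractory > 0:
            refractory -= 1  # integration of V_m is paused
            v_m = p["v_base"]
        else:
            v_m += dt / p["mem_tau"] * (
                -v_m + p["v_offset"] + p["input_scaling"] * v_d - a
            ) + noise[i]
            if v_m > p["threshold"]:
                spikes[i] = 1
                v_m = p["v_base"]  # reset (assumed to v_base)
                a += p["delta_a"] / p["tau_a"]  # spike-triggered increment (assumed)
                refractory = ref_steps
        vmem[i] = v_m
    return spikes, vmem


dt = PARAMS["deltat"]
t_sim = np.arange(int(round(1.0 / dt))) * dt
spikes_sim, vmem_sim = simulate_punit(np.cos(2 * np.pi * 744.66 * t_sim), PARAMS, seed=1)
print("baseline rate:", spikes_sim.sum() / (len(t_sim) * dt), "Hz")
```

Because this sketch may differ from jaxon in initial conditions, noise scaling and reset details, its baseline rate should not be expected to match the library's output exactly.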
3. Example
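The code at the top of this page is a minimal example to get you started: it simulates 10 s of baseline activity with the parameters of the cell 2010-11-08-al-invivo-1 and estimates the firing rate by smoothing the spike train with a Gaussian kernel.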
We can now plot the result of the simulation: the membrane voltage with the spikes marked, and the firing rate.
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# convert the JAX arrays to NumPy arrays for plotting
t = np.asarray(time)
v = np.asarray(vmem)
spike_mask = np.asarray(binary_spikes).astype(bool)

fig = make_subplots(specs=[[{"secondary_y": True}]])
# membrane voltage on the left y-axis
fig.add_scatter(x=t, y=v, mode="lines", name="V", secondary_y=False)
# mark each spike with an arrow placed one threshold above the voltage trace
fig.add_scatter(
    x=t[spike_mask],
    y=v[spike_mask] + float(params.threshold),
    mode="markers",
    marker_size=10,
    marker_color="red",
    marker_symbol="arrow-down",
    name="Spikes",
    secondary_y=False,
)
# firing rate on the right y-axis
fig.add_scatter(
    x=t, y=np.asarray(rate), name="Rate [Hz]", secondary_y=True, marker_color="magenta", line_width=4
)
fig.update_layout(xaxis_title="Time [s]", yaxis_title="Voltage [a.u.]")
fig.update_yaxes(title_text="Rate [Hz]", secondary_y=True)
# zoom in on the first 200 ms
fig.update_xaxes(range=[0, 0.2])
fig.show()
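The rate trace in the plot is obtained by smoothing the binary spike train with a Gaussian kernel. A NumPy sketch of that idea follows, assuming a unit-area kernel truncated at ±`ktime`·`sigma`; the exact conventions of jaxon's `gauss_kernel` and `spike_rate` may differ, and both helper names below are hypothetical.

```python
import numpy as np


def gauss_kernel_sketch(sigma, dt, ktime):
    """Unit-area Gaussian kernel sampled every dt, truncated at +/- ktime * sigma (assumed convention)."""
    half = int(round(ktime * sigma / dt))
    t = np.arange(-half, half + 1) * dt
    k = np.exp(-0.5 * (t / sigma) ** 2)
    return k / (k.sum() * dt)  # normalize so that the kernel integrates to one (units 1/s)


def spike_rate_sketch(binary_spikes, kernel):
    """Firing rate in Hz: the binary spike train convolved with the kernel."""
    return np.convolve(binary_spikes, kernel, mode="same")


dt = 5e-5
kernel_np = gauss_kernel_sketch(0.007, dt, 4)  # sigma = 7 ms, ktime = 4 as in the example
# a regular spike train at 100 Hz (one spike every 200 samples) for 1 s
spike_train_100hz = np.zeros(20000)
spike_train_100hz[::200] = 1
rate_np = spike_rate_sketch(spike_train_100hz, kernel_np)
```

Because the kernel integrates to one, each spike contributes an area of one to the rate, so away from the edges the estimate for the regular 100 Hz train is close to 100 Hz; near the edges it drops because part of the kernel extends beyond the data.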